# Source code for hysop.symbolic.relational

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import sympy as sm

from hysop.symbolic import Expr
from hysop.tools.htypes import first_not_None
from hysop.tools.numpywrappers import npw


class NAryRelation(Expr):
    """
    Represents a relation between n operands.

    The operands are joined with the infix operator returned by ``rel_op``
    and the whole relation is wrapped in parentheses, e.g. ``(a + b + c)``.
    The same layout is used for ``str()``, the sympy string printer and the
    C code printer.

    Parameters
    ----------
    args: tuple of Expr
    """

    @property
    def rel_op(self):
        # Concrete subclasses override this with their infix operator.
        return None

    def __new__(cls, *exprs):
        obj = super().__new__(cls, *exprs)
        return obj

    def _render_relation(self, render):
        """Render every operand with ``render`` and join them with ``rel_op``."""
        sep = f" {self.rel_op} "
        return "({})".format(sep.join(str(render(x)) for x in self.args))

    def __str__(self):
        return self._render_relation(str)

    def _sympystr(self, printer):
        return self._render_relation(printer._print)

    def _ccode(self, printer):
        return self._render_relation(printer._print)

    @property
    def is_number(self):
        # Always reported as a number, whatever the operands are.
        # NOTE(review): presumably so sympy treats the node as an opaque
        # constant — confirm intent.
        return True

    @property
    def free_symbols(self):
        # Reports no free symbols, even when the operands contain some.
        return ()
class LogicalRelation(NAryRelation):
    """Common base class for logical and comparison relations."""
class ArithmeticRelation(NAryRelation):
    """Common base class for arithmetic relations (sum, product, power)."""
class Add(ArithmeticRelation):
    """Sum of the operands, rendered with the ``+`` operator."""

    @property
    def rel_op(self):
        """Infix operator placed between the operands."""
        operator = "+"
        return operator
class Mul(ArithmeticRelation):
    """Product of the operands, rendered with the ``*`` operator."""

    @property
    def rel_op(self):
        """Infix operator placed between the operands."""
        operator = "*"
        return operator
class Pow(ArithmeticRelation):
    """
    Exponentiation of the operands, ``a ** b ** ...``.

    String and sympy output use the ``**`` operator. C has no power
    operator, so C code generation emits nested ``pow`` calls instead.
    """

    @property
    def rel_op(self):
        return "**"

    def _ccode(self, printer):
        """Render as nested ``pow`` calls, e.g. ``(pow(a, pow(b, c)))``."""
        rendered = [f"{printer._print(x)}" for x in self.args]
        if not rendered:
            # Same output as the generic n-ary rendering of an empty relation.
            return "()"
        # ``**`` is right-associative (a ** b ** c == a ** (b ** c)), so fold
        # from the right.
        # NOTE(review): OpenCL's ``pow`` takes floating-point arguments;
        # integer operands may need ``pown`` — confirm for the target backend.
        result = rendered[-1]
        for base in reversed(rendered[:-1]):
            result = f"pow({base}, {result})"
        return f"({result})"
class LogicalAND(LogicalRelation):
    """Logical conjunction of the operands, rendered with C-style ``&&``."""

    @property
    def rel_op(self):
        """Infix operator placed between the operands."""
        operator = "&&"
        return operator
class LogicalOR(LogicalRelation):
    """Logical disjunction of the operands, rendered with C-style ``||``."""

    @property
    def rel_op(self):
        """Infix operator placed between the operands."""
        operator = "||"
        return operator
class LogicalXOR(LogicalRelation):
    """
    Exclusive-or of the operands, rendered with ``^``.

    NOTE(review): in C ``^`` is the bitwise XOR. It only matches a logical
    XOR when every operand is 0 or 1 — confirm operands are boolean.
    """

    @property
    def rel_op(self):
        """Infix operator placed between the operands."""
        operator = "^"
        return operator
class LogicalEQ(LogicalRelation):
    """Equality comparison, rendered with ``==``."""

    @property
    def rel_op(self):
        """Infix operator placed between the operands."""
        operator = "=="
        return operator
class LogicalNE(LogicalRelation):
    """Inequality comparison, rendered with ``!=``."""

    @property
    def rel_op(self):
        """Infix operator placed between the operands."""
        operator = "!="
        return operator
class LogicalLT(LogicalRelation):
    """
    Strict less-than comparison, rendered with ``<``.

    NOTE(review): with more than two operands the output ``(a < b < c)`` is
    evaluated by C as ``((a < b) < c)``, not as a chained comparison.
    """

    @property
    def rel_op(self):
        """Infix operator placed between the operands."""
        operator = "<"
        return operator
class LogicalGT(LogicalRelation):
    """
    Strict greater-than comparison, rendered with ``>``.

    NOTE(review): with more than two operands the output ``(a > b > c)`` is
    evaluated by C as ``((a > b) > c)``, not as a chained comparison.
    """

    @property
    def rel_op(self):
        """Infix operator placed between the operands."""
        operator = ">"
        return operator
class LogicalLE(LogicalRelation):
    """
    Less-than-or-equal comparison, rendered with ``<=``.

    NOTE(review): with more than two operands the output is evaluated by C
    left to right, not as a chained comparison.
    """

    @property
    def rel_op(self):
        """Infix operator placed between the operands."""
        operator = "<="
        return operator
class LogicalGE(LogicalRelation):
    """
    Greater-than-or-equal comparison, rendered with ``>=``.

    NOTE(review): with more than two operands the output is evaluated by C
    left to right, not as a chained comparison.
    """

    @property
    def rel_op(self):
        """Infix operator placed between the operands."""
        operator = ">="
        return operator
class BinaryRelation(NAryRelation):
    """
    Relation between exactly two operands.

    Parameters
    ----------
    lhs : Expr
        Left-hand operand, also kept as the ``lhs`` attribute.
    rhs : Expr
        Right-hand operand, also kept as the ``rhs`` attribute.
    """

    def __new__(cls, lhs, rhs):
        instance = super().__new__(cls, lhs, rhs)
        # Keep direct handles on both sides in addition to ``args``.
        instance.lhs, instance.rhs = lhs, rhs
        return instance
class Assignment(BinaryRelation):
    """
    Represents variable assignment for code generation.

    Parameters
    ----------
    lhs : Expr
        Assignment target.
    rhs : Expr
        Assigned value.
    """

    def __str__(self):
        # Show the target's ``name`` when it has one, otherwise the target itself.
        lhs = first_not_None(getattr(self.lhs, "name", None), self.lhs)
        rel_op = self.rel_op
        if rel_op == "=":
            # Plain assignments are displayed as ":="; augmented operators
            # ("+=", "*=", ...) are displayed unchanged.
            rel_op = ":" + rel_op
        return "{} {} {};".format(
            lhs, rel_op, sm.printing.str.StrPrinter()._print(self.rhs)
        )

    def _ccode(self, printer):
        """
        Render the assignment as a C statement.

        For a plain ``=`` assignment, first try ``lhs.declare(init=...)`` so
        a declarable variable is emitted as a declaration with initializer.
        Otherwise, or when that fails, emit ``lhs op rhs;``.
        """
        # Print the value once; it is used by both code paths.
        rhs = printer._print(self.rhs)
        # Only a plain assignment may become a declaration: turning ``x += y``
        # into a declaration initialized with ``y`` would drop the operator.
        if self.rel_op == "=":
            try:
                return self.lhs.declare(init=rhs)
            except Exception:
                # lhs cannot be declared (no ``declare`` or it failed):
                # fall back to a plain assignment statement.
                pass
        return "{} {} {};".format(printer._print(self.lhs), self.rel_op, rhs)

    @property
    def rel_op(self):
        return "="

    @classmethod
    def assign(cls, lhs, rhs, skip_zero_rhs=False):
        """
        Build a tuple of ``cls`` assignments from the given operands.

        Parameters
        ----------
        lhs, rhs : sympy.Basic or npw.ndarray
            Two arrays must have the same shape and are paired element-wise
            after flattening with ``ravel``. If only one side is an array,
            the other operand is paired with every element. Two non-array
            operands must both be ``sympy.Basic``.
        skip_zero_rhs : bool, optional
            When True, pairs whose right-hand side compares equal to 0 are
            dropped.

        Returns
        -------
        tuple
            The created ``cls`` instances, in flattened element order.

        Raises
        ------
        TypeError
            If the operand types are not supported.
        """

        def create_expr(value):
            return (not skip_zero_rhs) or (value != 0)

        lhs_is_array = isinstance(lhs, npw.ndarray)
        rhs_is_array = isinstance(rhs, npw.ndarray)

        if lhs_is_array and rhs_is_array:
            assert rhs.shape == lhs.shape, (lhs.shape, rhs.shape)
            pairs = zip(lhs.ravel().tolist(), rhs.ravel().tolist())
        elif lhs_is_array:
            targets = lhs.ravel().tolist()
            pairs = zip(targets, (rhs,) * len(targets))
        elif rhs_is_array:
            values = rhs.ravel().tolist()
            pairs = zip((lhs,) * len(values), values)
        elif isinstance(lhs, sm.Basic) and isinstance(rhs, sm.Basic):
            pairs = ((lhs, rhs),)
        else:
            msg = "Cannot handle operand types:\n *lhs: {}\n *rhs: {}\n"
            msg = msg.format(type(lhs), type(rhs))
            raise TypeError(msg)

        return tuple(cls(dst, src) for dst, src in pairs if create_expr(src))
class AugmentedAssignment(Assignment):
    """
    Base class for augmented assignments such as ``+=`` or ``*=``.

    Subclasses set the class attribute ``_symbol``. The operator is that
    symbol followed by ``=``.
    """

    @property
    def rel_op(self):
        return "".join((self._symbol, "="))
class AddAugmentedAssignment(AugmentedAssignment):
    """In-place addition, ``lhs += rhs``."""

    _symbol = "+"
class SubAugmentedAssignment(AugmentedAssignment):
    """In-place subtraction, ``lhs -= rhs``."""

    _symbol = "-"
class MulAugmentedAssignment(AugmentedAssignment):
    """In-place multiplication, ``lhs *= rhs``."""

    _symbol = "*"
class DivAugmentedAssignment(AugmentedAssignment):
    """In-place division, ``lhs /= rhs``."""

    _symbol = "/"
class ModAugmentedAssignment(AugmentedAssignment):
    """In-place remainder, ``lhs %= rhs``."""

    _symbol = "%"
class NAryFunction(Expr):
    """
    Represents a call to a named function with n arguments.

    The call is rendered as ``fname(arg0, arg1, ...)``. ``fname`` must be
    provided by concrete subclasses. The same layout is used for ``str()``,
    the sympy string printer and the C code printer.

    Parameters
    ----------
    args: tuple of Expr
    """

    @property
    def fname(self):
        # Abstract: concrete subclasses must override this property.
        raise NotImplementedError(
            f"{type(self).__name__} does not define a function name (fname)."
        )

    def __new__(cls, *exprs):
        obj = super().__new__(cls, *exprs)
        return obj

    def _render_call(self, render):
        """Render every argument with ``render`` and format the call."""
        return "{}({})".format(
            self.fname, ", ".join(str(render(x)) for x in self.args)
        )

    def __str__(self):
        return self._render_call(str)

    def _sympystr(self, printer):
        return self._render_call(printer._print)

    def _ccode(self, printer):
        return self._render_call(printer._print)

    @property
    def is_number(self):
        # Always reported as a number, whatever the arguments are.
        # NOTE(review): presumably so sympy treats the call as an opaque
        # constant — confirm intent.
        return True

    @property
    def free_symbols(self):
        # Reports no free symbols, even when the arguments contain some.
        return ()
class UnaryFunction(NAryFunction):
    """Function call taking exactly one argument."""

    def __new__(cls, a):
        instance = super().__new__(cls, a)
        return instance
class BinaryFunction(NAryFunction):
    """Function call taking exactly two arguments."""

    def __new__(cls, lhs, rhs):
        instance = super().__new__(cls, lhs, rhs)
        return instance
class Max(BinaryFunction):
    """Maximum of two arguments, rendered as ``max(a, b)``."""

    @property
    def fname(self):
        """Name of the function emitted in the generated call."""
        name = "max"
        return name
class Min(BinaryFunction):
    """Minimum of two arguments, rendered as ``min(a, b)``."""

    @property
    def fname(self):
        """Name of the function emitted in the generated call."""
        name = "min"
        return name
class Round(UnaryFunction):
    """Rounding of a single argument, rendered as ``round(a)``."""

    @property
    def fname(self):
        """Name of the function emitted in the generated call."""
        name = "round"
        return name